def getTracesWithNodeInfo(G, nodeCommunity=False, seed=18):
    """Build the Plotly 3D traces (edges + nodes) for a networkx graph.

    Node positions come from a 3-dimensional ``nx.spring_layout``.

    Args:
        G: a networkx graph. Node labels may be any hashable values.
        nodeCommunity: per-node colour values. It may be:
            * falsy (``False``/``None``/empty), in which case every node gets colour 1;
            * a dict mapping node -> value;
            * a sequence of values. For graphs labelled exactly 0..n-1 the
              sequence is indexed by node id. Otherwise it is matched
              positionally to ``G.nodes()`` order.
        seed: random seed forwarded to ``nx.spring_layout`` so the layout is
            reproducible (defaults to the previously hard-coded 18).

    Returns:
        ``[trace_edges, trace_nodes]``: two ``plot3D.Scatter3d`` traces.
    """
    positions = nx.spring_layout(G, dim=3, seed=seed)

    nodes = list(G.nodes())
    # For graphs labelled 0..n-1, keep the historical convention that
    # position i of nodeCommunity belongs to node i, independent of insertion order.
    if set(nodes) == set(range(len(nodes))):
        nodes = list(range(len(nodes)))

    if nodeCommunity is None or nodeCommunity is False or len(nodeCommunity) == 0:
        node_colors = [1] * len(nodes)  # one entry per node (was n + 1)
    elif isinstance(nodeCommunity, dict):
        node_colors = [nodeCommunity[node] for node in nodes]
    else:
        node_colors = list(nodeCommunity)

    # Plotly wants separate X, Y, Z coordinate lists, in the same order as `nodes`
    # so that coordinates, colours and hover text all line up.
    x_nodes = [positions[node][0] for node in nodes]
    y_nodes = [positions[node][1] for node in nodes]
    z_nodes = [positions[node][2] for node in nodes]

    # Each edge contributes [start, end, None]; the None breaks the line so
    # Plotly draws disconnected segments within a single trace.
    x_edges, y_edges, z_edges = [], [], []
    for u, v in G.edges():
        start, end = positions[u], positions[v]
        x_edges += [start[0], end[0], None]
        y_edges += [start[1], end[1], None]
        z_edges += [start[2], end[2], None]

    trace_edges = plot3D.Scatter3d(x=x_edges,
                                   y=y_edges,
                                   z=z_edges,
                                   mode='lines',
                                   line=dict(color='black', width=2),
                                   hoverinfo='none')

    trace_nodes = plot3D.Scatter3d(x=x_nodes,
                                   y=y_nodes,
                                   z=z_nodes,
                                   mode='markers',
                                   marker=dict(symbol='circle',
                                               size=10,
                                               color=node_colors,  # colour by community value
                                               colorscale='Viridis',  # continuous purple -> yellow scale
                                               line=dict(color='black', width=0.5)),
                                   text=nodes,
                                   hoverinfo='text')
    return [trace_edges, trace_nodes]
def getLayout(plotName):
    """Return a 1024x980 Plotly layout titled ``plotName`` with all 3D axes hidden.

    Background, lines, zero lines, grids, tick labels, spikes and axis titles
    are all suppressed so that only the network itself is visible.
    """
    hidden_axis = dict(showbackground=False,
                       showline=False,
                       zeroline=False,
                       showgrid=False,
                       showticklabels=False,
                       showspikes=False,
                       title='')
    # Each axis receives its own copy of the shared settings.
    scene_settings = dict(xaxis=dict(hidden_axis),
                          yaxis=dict(hidden_axis),
                          zaxis=dict(hidden_axis))
    return plot3D.Layout(title=plotName,
                         showlegend=False,
                         width=1024,
                         height=980,
                         scene=scene_settings,
                         hovermode='closest')
def plot3Dnetwork(G, name, nodeCommunity=False):
    """Render ``G`` as an interactive 3D Plotly figure titled ``name``.

    ``nodeCommunity`` is forwarded unchanged to ``getTracesWithNodeInfo`` to
    colour the nodes. The figure is displayed via ``Figure.show()``.
    """
    figure = plot3D.Figure(data=getTracesWithNodeInfo(G, nodeCommunity),
                           layout=getLayout(name))
    figure.show()